import torch.nn as nn
import torch.nn.functional as F
#from dgl.nn.pytorch import GraphConv
from torch_geometric.nn import GCNConv
import scipy
import numpy as np
import torch
# Fixed global seed shared by the CPU and CUDA random number generators below.
# NOTE: everything in this block runs at import time and changes process-wide
# PyTorch state for every module that imports this file.
g_seed=39788
# Seed PyTorch's default random number generator.
torch.manual_seed(g_seed)
# Project-local normalisation layer used by GCN_Body; imported after seeding.
from models.Norm.norm import Norm
# Disable cuDNN auto-tuning and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Ask PyTorch to use deterministic algorithms globally.
# NOTE(review): per the PyTorch reproducibility notes, operations without a
# deterministic implementation raise at runtime under this setting, and some
# CUDA versions additionally need CUBLAS_WORKSPACE_CONFIG set - confirm for the
# target environment.
torch.use_deterministic_algorithms(True)
# Seed the random number generators of all CUDA devices.
torch.cuda.manual_seed_all(g_seed)
class GCN(nn.Module):
    """Wrapper that runs a ``GCN_Body`` encoder and classifies its hidden output.

    The previous implementation could not be instantiated or called:

    * ``GCN_Body`` was built positionally without ``nclass``, so every later
      argument was shifted by one slot (e.g. ``activation`` was used as the
      dropout probability).
    * ``forward`` passed ``(g, x)`` to a body whose signature is
      ``forward(x, edges)`` and then fed the returned tuple to ``nn.Linear``.

    Args:
        nfeat: Input feature dimension.
        nhid: Hidden dimension.
        nclass: Output dimension of the classifier.
        dropout: Dropout probability forwarded to ``GCN_Body``.
        activation: Callable applied after every normalised layer.
        node_s0, node_s1: Node indices of the two sensitive groups.
        fairness: Forwarded to ``GCN_Body`` as ``fairness``.
        layer_num: Number of graph-convolution layers (must be >= 2).
        norm_type: Normalisation type forwarded to ``GCN_Body``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout, activation, node_s0, node_s1, fairness,layer_num, norm_type):
        super(GCN, self).__init__()
        # Keyword arguments keep this call aligned with GCN_Body's signature.
        # NOTE(review): ``layer_num`` is mapped to GCN_Body's ``k`` (number of
        # layers) and ``dlayer`` is left off because this wrapper has no
        # matching parameter - confirm this is the intended mapping.
        self.body = GCN_Body(
            nfeat=nfeat,
            nhid=nhid,
            nclass=nclass,
            dropout=dropout,
            activation=activation,
            node_s0=node_s0,
            node_s1=node_s1,
            fairness=fairness,
            dlayer=False,
            k=layer_num,
            norm_type=norm_type,
        )
        self.fc = nn.Linear(nhid,nclass)

    def forward(self, g, x):
        """Return class logits for node features ``x`` on graph connectivity ``g``.

        ``g`` is handed to ``GCN_Body.forward`` as its ``edges`` argument.
        """
        # GCN_Body.forward returns a tuple whose second item is the hidden
        # representation of the last layer, in every mode.
        hidden = self.body(x, g)[1]
        return self.fc(hidden)

# def GCN(nn.Module):
class GCN_Body(nn.Module):
    """Multi-layer GCN encoder with optional per-group normalisation.

    The module owns ``k`` graph-convolution layers, each followed by a
    ``Norm`` layer and the user-supplied activation, a linear classifier
    (``fc``), a linear adversary head (``adv``) and two Adam optimizers:
    ``optimizer_G`` (convolutions, norms, ``fc``) and ``optimizer_A``
    (``adv``).

    Modes, selected by the constructor flags:

    * ``fairness`` truthy: nodes in ``node_s0`` / ``node_s1`` are normalised
      by separate layers (``norms1`` / ``norms2``) and ``forward`` also
      returns the per-layer statistics reported by those layers.
    * ``dlayer`` truthy (and ``fairness`` falsy): same per-group
      normalisation, statistics are discarded.
    * otherwise: a single shared ``norms`` list is applied to all nodes.

    Args:
        nfeat: Input feature dimension.
        nhid: Hidden dimension of every layer.
        nclass: Output dimension of ``fc``.
        dropout: Probability for ``self.dropout`` (not applied in ``forward``).
        activation: Callable applied after each normalised layer.
        node_s0, node_s1: Row indices of the two sensitive groups.
        fairness: Enables per-group norms and statistic outputs.
        dlayer: Enables per-group norms without statistic outputs.
        k: Number of convolution layers, must be >= 2.
        norm_type: Passed to ``Norm``; ``'bn'`` and ``'gn'`` are expected to
            return ``(output, stat_a, stat_b)``.
        gpu: Stored as ``gpu_id``; no longer needed by ``forward``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout, activation, node_s0, node_s1, fairness, dlayer, k: int = 2, norm_type='gn', gpu=0):
        super(GCN_Body, self).__init__()
        self.activation = activation
        # NOTE(review): gc1-gc3 and dropout are never used by forward(). They
        # are kept, in the original order, so existing checkpoints and the
        # order in which layers are initialised are unaffected.
        self.gc1 = GCNConv(nfeat, nhid)
        self.gc2 = GCNConv(nhid, nhid)
        self.gc3 = GCNConv(nhid, nhid)
        self.dropout = nn.Dropout(dropout)
        self.fairness = fairness
        self.dlayer = dlayer
        self.nt = norm_type
        self.node_s0 = node_s0
        self.node_s1 = node_s1
        self.fc = nn.Linear(nhid, nclass)
        self.adv = nn.Linear(nhid, 1)
        self.gpu_id = gpu

        per_group = bool(fairness or dlayer)
        if per_group:
            self.norms1 = torch.nn.ModuleList()
            self.norms2 = torch.nn.ModuleList()
        else:
            self.norms = torch.nn.ModuleList()
        assert k >= 2, "GCN_Body needs at least two convolution layers (k >= 2)"
        self.k = k

        # Layers are created in the order conv_i, norm(s)_i so that the
        # construction sequence matches the previous implementation.
        convs = []
        in_dim = nfeat
        for _ in range(k):
            convs.append(GCNConv(in_dim, nhid))
            if per_group:
                self.norms1.append(Norm(norm_type, nhid))
                self.norms2.append(Norm(norm_type, nhid))
            else:
                self.norms.append(Norm(norm_type, nhid))
            in_dim = nhid
        self.conv = nn.ModuleList(convs)

        if per_group:
            G_params = (list(self.conv.parameters())
                        + list(self.norms1.parameters())
                        + list(self.norms2.parameters())
                        + list(self.fc.parameters()))
        else:
            G_params = (list(self.conv.parameters())
                        + list(self.norms.parameters())
                        + list(self.fc.parameters()))
        self.optimizer_G = torch.optim.Adam(G_params, lr=0.001, weight_decay=1e-5)
        self.A_loss = 0
        self.G_loss = 0
        self.optimizer_A = torch.optim.Adam(self.adv.parameters(), lr=0.001, weight_decay=1e-5)
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x, edges):
        """Encode ``x`` over ``edges`` and classify the final hidden state.

        Returns:
            ``(logits, hidden)`` when ``fairness`` is falsy, otherwise
            ``(logits, hidden, means1, means2, stds1, stds2, all_std)`` where
            the statistics are concatenated along dim 0 with the most recent
            layer first.

        Raises:
            ValueError: if per-group normalisation is enabled with a
                ``norm_type`` other than ``'bn'`` or ``'gn'`` (this used to
                return ``None`` silently).
        """
        if self.fairness or self.dlayer:
            if self.nt not in ('bn', 'gn'):
                raise ValueError(
                    "per-group normalisation supports norm_type 'bn' or 'gn', "
                    "got {!r}".format(self.nt))
            if self.fairness:
                return self._forward_fair(x, edges)
            return self._forward_grouped(x, edges)
        return self._forward_shared(x, edges)

    @staticmethod
    def _as_row(vec):
        """Reshape a 1-D tensor of length d into shape (1, d)."""
        return torch.reshape(vec, (1, vec.size(dim=0)))

    def _apply_group_norms(self, i, h):
        """Normalise both groups of ``h`` in place with layer ``i``; return their stats."""
        # Group 0 is written back before group 1 is read, as before.
        out0, stat_a0, stat_b0 = self.norms1[i](h[self.node_s0, :])
        h[self.node_s0, :] = out0
        out1, stat_a1, stat_b1 = self.norms2[i](h[self.node_s1, :])
        h[self.node_s1, :] = out1
        return stat_a0, stat_b0, stat_a1, stat_b1

    def _forward_fair(self, x, edges):
        """Per-group normalised pass that also collects per-layer statistics."""
        # Accumulators live on the input's device instead of a hard-coded
        # CUDA device, so the fairness path also runs on CPU.
        device = x.device
        means1 = torch.empty(0, device=device)
        means2 = torch.empty(0, device=device)
        stds1 = torch.empty(0, device=device)
        stds2 = torch.empty(0, device=device)
        all_std = torch.empty(0, device=device)
        # In 'bn' mode every statistic is stacked as a (1, d) row; in 'gn'
        # mode the values are concatenated exactly as the norm returns them.
        as_rows = self.nt == 'bn'
        for i in range(self.k):
            h = self.conv[i](x, edges)
            m0, s0, m1, s1 = self._apply_group_norms(i, h)
            layer_std = torch.std(h, 0)
            if as_rows:
                m0, s0 = self._as_row(m0), self._as_row(s0)
                m1, s1 = self._as_row(m1), self._as_row(s1)
                layer_std = self._as_row(layer_std)
            # New entries are prepended: the latest layer comes first.
            means1 = torch.cat((m0, means1), 0)
            means2 = torch.cat((m1, means2), 0)
            stds1 = torch.cat((s0, stds1), 0)
            stds2 = torch.cat((s1, stds2), 0)
            all_std = torch.cat((layer_std, all_std), 0)
            x = self.activation(h)
        return self.fc(x), x, means1, means2, stds1, stds2, all_std

    def _forward_grouped(self, x, edges):
        """Per-group normalised pass without statistics."""
        for i in range(self.k):
            h = self.conv[i](x, edges)
            self._apply_group_norms(i, h)
            x = self.activation(h)
        return self.fc(x), x

    def _forward_shared(self, x, edges):
        """Pass that applies one shared norm layer to all nodes."""
        returns_stats = self.nt in ('bn', 'gn')
        for i in range(self.k):
            out = self.norms[i](self.conv[i](x, edges))
            if returns_stats:
                # 'bn' / 'gn' norms return (output, stat_a, stat_b).
                out, _, _ = out
            x = self.activation(out)
        return self.fc(x), x

    def _compute_common_losses(self, h, sens, idx_train, y, labels, idx_s0, idx_s1, kappa, eta, hp3):
        """Set ``cov``, ``cls_loss``, ``adv_loss`` and ``relaxed`` for a generator step."""
        if kappa == 0:
            self.cov = 0
        else:
            # Flatten both operands before multiplying. ``sens[idx_train]`` is
            # 1-D while ``y[idx_train]`` is (n, 1); without flattening the
            # product broadcasts to (n, n) and its mean is the product of two
            # centred means, i.e. approximately zero.
            sens_train = sens[idx_train].reshape(-1)
            score_train = torch.sigmoid(y)[idx_train].reshape(-1)
            self.cov = torch.abs(torch.mean(
                (sens_train - torch.mean(sens_train))
                * (score_train - torch.mean(score_train))))
        self.cls_loss = self.criterion(y[idx_train], labels[idx_train].unsqueeze(1).float())
        if eta == 0:
            self.adv_loss = 0
        else:
            s_g = self.adv(h)
            self.adv_loss = self.criterion(s_g[idx_train], sens[idx_train].unsqueeze(1).float())
        if hp3 == 0:
            self.relaxed = 0
        else:
            self.relaxed = fair_metrics(y, sens, idx_train, idx_s0, idx_s1)

    def _update_adversary(self, h, sens, idx_train):
        """Run one optimizer step of the adversary head on detached ``h``."""
        self.adv.requires_grad_(True)
        self.optimizer_A.zero_grad()
        s_g = self.adv(h.detach())
        self.A_loss = self.criterion(s_g[idx_train], sens[idx_train].unsqueeze(1).float())
        self.A_loss.backward()
        self.optimizer_A.step()

    def optimize(self, h, sens, idx_train, y, labels, means1, means2, stds1, stds2, all_std, idx_s0, idx_s1, hp1, hp2, kappa, eta, hp3,  adversarial=False):
        """Run one training step including the normalisation fairness loss.

        ``G_loss = cls_loss + fair_norm + kappa*cov + hp3*relaxed - eta*adv_loss``
        is back-propagated and ``optimizer_G`` stepped. With
        ``adversarial=True`` the adversary head is then updated on detached
        ``h``. ``h`` and ``y`` must come from a forward pass whose autograd
        graph is still alive.
        """
        self.train()
        # Update encoder / classifier; the adversary is frozen for this step.
        self.optimizer_G.zero_grad()
        self.adv.requires_grad_(False)
        self._compute_common_losses(h, sens, idx_train, y, labels, idx_s0, idx_s1, kappa, eta, hp3)
        # Group sizes are reduced on the tensor rather than with the builtin
        # sum(), which iterates element by element.
        self.fair_norm = fairness_loss(
            means1, means2, stds1, stds2, all_std, hp1, hp2,
            torch.as_tensor(idx_s0).sum(), torch.as_tensor(idx_s1).sum(),
            self.activation)
        self.G_loss = self.cls_loss + self.fair_norm + kappa*self.cov + hp3*self.relaxed - eta * self.adv_loss
        self.G_loss.backward()
        self.optimizer_G.step()
        if adversarial:
            self._update_adversary(h, sens, idx_train)

    def optimize_slayer(self, h, sens, idx_train, y, labels, idx_s0, idx_s1, hp1, hp2, kappa, eta, hp3,  adversarial=False):
        """Run one training step without the normalisation fairness loss.

        ``G_loss = cls_loss + kappa*cov + hp3*relaxed - eta*adv_loss``.
        ``hp1`` and ``hp2`` are accepted for signature parity with
        ``optimize`` and are not used.
        """
        self.train()
        # Update encoder / classifier; the adversary is frozen for this step.
        self.optimizer_G.zero_grad()
        self.adv.requires_grad_(False)
        self._compute_common_losses(h, sens, idx_train, y, labels, idx_s0, idx_s1, kappa, eta, hp3)
        self.G_loss = self.cls_loss + kappa*self.cov + hp3*self.relaxed - eta * self.adv_loss
        self.G_loss.backward()
        self.optimizer_G.step()
        if adversarial:
            self._update_adversary(h, sens, idx_train)
        

def fairness_loss(means1: torch.Tensor, means2: torch.Tensor, stds1: torch.Tensor, stds2: torch.Tensor, all_std: torch.Tensor, hp1, hp2, len_s0, len_s1,act):
    """Return the weighted fairness penalty on group normalisation statistics.

    The penalty is ``hp1 * ||means1 - means2||^2 +
    hp2 * (||stds1||^2 + ||stds2||^2)``. Both weighted terms are printed on
    every call.

    ``all_std``, ``len_s0``, ``len_s1`` and ``act`` are accepted but not used
    by the current formula.
    """
    mean_gap_sq = torch.square(torch.norm(means1 - means2))
    spread_sq = torch.square(torch.norm(stds1)) + torch.square(torch.norm(stds2))

    weighted_mean_term = hp1 * mean_gap_sq
    weighted_std_term = hp2 * spread_sq

    print('mean loss: ', weighted_mean_term)
    print('std loss: ', weighted_std_term)
    return weighted_mean_term + weighted_std_term



def fair_metrics(output, sens, idx_train, idx_s0, idx_s1):
    """Return the squared gap between the two groups' mean soft predictions.

    Each output is mapped to ``tanh(3 * max(output, 0))``, restricted to
    ``idx_train``, and summed per group. Each sum is divided by the sum of
    that group's index (``idx_s0`` / ``idx_s1``), which is the group size
    when the index is a boolean mask. The squared difference of the two
    ratios is returned.

    Args:
        output: Model outputs, indexed by ``idx_train``.
        sens: Unused; kept for interface compatibility.
        idx_train: Index selecting the training rows of ``output``.
        idx_s0, idx_s1: Indices into the training rows for each group.
            NOTE(review): presumably boolean masks over the training rows -
            confirm against the caller.

    Note:
        A group whose index sums to zero causes a division by zero, exactly
        as in the previous implementation.
    """
    c = 3
    soft_pred = torch.tanh(c * torch.clamp(output, min=0))[idx_train]

    # Reduce on the tensor instead of the builtin sum(), which walks the
    # index element by element in Python.
    denom_s0 = torch.as_tensor(idx_s0).sum()
    denom_s1 = torch.as_tensor(idx_s1).sum()

    rate_s0 = torch.sum(soft_pred[idx_s0]) / denom_s0
    rate_s1 = torch.sum(soft_pred[idx_s1]) / denom_s1
    return torch.pow(rate_s0 - rate_s1, 2)
